package cn.yu.spring.boot.milvus.service;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.PageUtil;
import cn.hutool.json.JSONUtil;
import cn.yu.spring.boot.milvus.model.CalcDistanceParam;
import cn.yu.spring.boot.milvus.model.FieldDto;
import cn.yu.spring.boot.milvus.model.SearchParamsDto;
import com.google.common.collect.Lists;
import io.milvus.client.MilvusServiceClient;
import io.milvus.grpc.*;
import io.milvus.param.IndexType;
import io.milvus.param.MetricType;
import io.milvus.param.R;
import io.milvus.param.RpcStatus;
import io.milvus.param.collection.*;
import io.milvus.param.control.GetQuerySegmentInfoParam;
import io.milvus.param.dml.DeleteParam;
import io.milvus.param.dml.InsertParam;
import io.milvus.param.dml.QueryParam;
import io.milvus.param.dml.SearchParam;
import io.milvus.param.index.CreateIndexParam;
import io.milvus.param.partition.*;
import io.milvus.response.DescCollResponseWrapper;
import io.milvus.response.GetCollStatResponseWrapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.List;
import java.util.concurrent.TimeUnit;

/**
 * Spring-managed convenience wrapper around {@link MilvusServiceClient}.
 *
 * <p>Each public method builds the matching Milvus parameter object and issues the call through
 * {@code milvusServiceClient.withTimeout(20, TimeUnit.SECONDS)}, so every operation uses the same
 * timeout. Searches and index creation always use {@link MetricType#IP}.
 *
 * @author zy
 * @date 2022/4/2 1:53 PM
 **/
@Component
public class MilvusService {

    /** Timeout, in seconds, passed to {@code withTimeout} for every call made by this service. */
    private static final long RPC_TIMEOUT_SECONDS = 20L;

    /** {@code nprobe} value sent with paged searches. */
    private static final int SEARCH_NPROBE = 128;

    /** Search parameters (JSON) sent with non-paged searches. */
    private static final String SEARCH_PARAMS_JSON = "{\"nprobe\":128}";

    /** Extra index parameters (JSON) sent when creating an index. */
    private static final String INDEX_PARAMS_JSON = "{\"nlist\":4096}";

    @Autowired
    MilvusServiceClient milvusServiceClient;

    /**
     * Searches a collection for the {@code topK} nearest entries of each query vector.
     *
     * @param collectionName collection to search
     * @param vectorName     name of the vector field to search on
     * @param vectors        query vectors
     * @param outFields      fields to return with each hit
     * @param expr           filter expression
     * @param topK           number of hits requested per query vector
     * @return the raw search response
     */
    public R<SearchResults> search(String collectionName, String vectorName, List<List<Float>> vectors, List<String> outFields, String expr, Integer topK) {
        SearchParam searchParam = searchBuilder(collectionName, vectorName, vectors, outFields, expr, topK, SEARCH_PARAMS_JSON)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).search(searchParam);
    }

    /**
     * Same as {@link #search(String, String, List, List, String, Integer)} but restricted to the
     * given partitions.
     *
     * @param collectionName collection to search
     * @param partitionNames partitions to search in
     * @param vectorName     name of the vector field to search on
     * @param vectors        query vectors
     * @param outFields      fields to return with each hit
     * @param expr           filter expression
     * @param topK           number of hits requested per query vector
     * @return the raw search response
     */
    public R<SearchResults> search(String collectionName, List<String> partitionNames, String vectorName, List<List<Float>> vectors, List<String> outFields, String expr, Integer topK) {
        SearchParam searchParam = searchBuilder(collectionName, vectorName, vectors, outFields, expr, topK, SEARCH_PARAMS_JSON)
                .withPartitionNames(partitionNames)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).search(searchParam);
    }

    /**
     * Searches a collection one page at a time; {@code pageSize} is used as the top-K value and the
     * page offset is sent inside the search parameters.
     *
     * @param collectionName collection to search
     * @param vectorName     name of the vector field to search on
     * @param vectors        query vectors
     * @param outFields      fields to return with each hit
     * @param expr           filter expression
     * @param pageNo         page number (the first page is 1)
     * @param pageSize       number of hits per page
     * @return the raw search response
     */
    public R<SearchResults> searchPage(String collectionName, String vectorName, List<List<Float>> vectors, List<String> outFields, String expr, Integer pageNo, Integer pageSize) {
        SearchParam searchParam = searchBuilder(collectionName, vectorName, vectors, outFields, expr, pageSize, pagedSearchParams(pageNo, pageSize))
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).search(searchParam);
    }

    /**
     * Same as {@link #searchPage(String, String, List, List, String, Integer, Integer)} but
     * restricted to the given partitions.
     *
     * @param collectionName collection to search
     * @param partitionNames partitions to search in
     * @param vectorName     name of the vector field to search on
     * @param vectors        query vectors
     * @param outFields      fields to return with each hit
     * @param expr           filter expression
     * @param pageNo         page number (the first page is 1)
     * @param pageSize       number of hits per page
     * @return the raw search response
     */
    public R<SearchResults> searchPage(String collectionName, List<String> partitionNames, String vectorName, List<List<Float>> vectors, List<String> outFields, String expr, Integer pageNo, Integer pageSize) {
        SearchParam searchParam = searchBuilder(collectionName, vectorName, vectors, outFields, expr, pageSize, pagedSearchParams(pageNo, pageSize))
                .withPartitionNames(partitionNames)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).search(searchParam);
    }

    /**
     * Computes, locally, the inner product of the first left vector with every right vector.
     *
     * <p>Only {@code vectorsLeft.get(0)} is used; any further left vectors are ignored. Each product
     * is summed over the length of the right vector, so a right vector longer than the first left
     * vector fails with an {@link IndexOutOfBoundsException}.
     *
     * @param vectorsLeft  left-hand vectors; only the first one is read
     * @param vectorsRight right-hand vectors
     * @return one inner product per right vector, in the order of {@code vectorsRight}
     */
    public List<Float> calcDistance(List<List<Float>> vectorsLeft, List<List<Float>> vectorsRight) {
        List<Float> distanceList = Lists.newArrayList();
        for (List<Float> right : vectorsRight) {
            double innerProduct = 0.00;
            for (int i = 0; i < right.size(); i++) {
                // The component pair still goes through CalcDistanceParam so the multiplication is
                // carried out on whatever numeric types that holder exposes, as it was before.
                CalcDistanceParam pair = CalcDistanceParam.builder()
                        .vectorsLeft(vectorsLeft.get(0).get(i))
                        .vectorsRight(right.get(i))
                        .build();
                innerProduct += pair.getVectorsLeft() * pair.getVectorsRight();
            }
            distanceList.add((float) innerProduct);
        }
        return distanceList;
    }

    /**
     * Loads a collection.
     *
     * @param collectionName collection to load
     * @return the status response
     */
    public R<RpcStatus> loadCollection(String collectionName) {
        LoadCollectionParam param = LoadCollectionParam.newBuilder()
                .withCollectionName(collectionName)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).loadCollection(param);
    }

    /**
     * Releases a collection.
     *
     * @param collectionName collection to release
     * @return the status response
     */
    public R<RpcStatus> releaseCollection(String collectionName) {
        ReleaseCollectionParam param = ReleaseCollectionParam.newBuilder()
                .withCollectionName(collectionName)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).releaseCollection(param);
    }

    /**
     * Loads the given partitions of a collection.
     *
     * @param collectionName collection that owns the partitions
     * @param partitionNames partitions to load
     * @return the status response
     */
    public R<RpcStatus> loadPartitions(String collectionName, List<String> partitionNames) {
        LoadPartitionsParam param = LoadPartitionsParam.newBuilder()
                .withCollectionName(collectionName)
                .withPartitionNames(partitionNames)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).loadPartitions(param);
    }

    /**
     * Releases the given partitions of a collection.
     *
     * @param collectionName collection that owns the partitions
     * @param partitionNames partitions to release
     * @return the status response
     */
    public R<RpcStatus> releasePartitions(String collectionName, List<String> partitionNames) {
        ReleasePartitionsParam param = ReleasePartitionsParam.newBuilder()
                .withCollectionName(collectionName)
                .withPartitionNames(partitionNames)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).releasePartitions(param);
    }

    /**
     * Inserts rows into a collection.
     *
     * @param collectionName target collection
     * @param fieldDtoList   one entry per field, each carrying the field name and its values
     * @return the mutation response
     */
    public R<MutationResult> insert(String collectionName, List<FieldDto> fieldDtoList) {
        InsertParam insertParam = InsertParam.newBuilder()
                .withCollectionName(collectionName)
                .withFields(toInsertFields(fieldDtoList))
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).insert(insertParam);
    }

    /**
     * Inserts rows into one partition of a collection.
     *
     * @param collectionName target collection
     * @param partitionName  target partition
     * @param fieldDtoList   one entry per field, each carrying the field name and its values
     * @return the mutation response
     */
    public R<MutationResult> insert(String collectionName, String partitionName, List<FieldDto> fieldDtoList) {
        InsertParam insertParam = InsertParam.newBuilder()
                .withCollectionName(collectionName)
                .withPartitionName(partitionName)
                .withFields(toInsertFields(fieldDtoList))
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).insert(insertParam);
    }

    /**
     * Queries a collection with a filter expression.
     *
     * @param collectionName collection to query
     * @param expr           filter expression
     * @param outFields      fields to return
     * @return the raw query response
     */
    public R<QueryResults> queryByExpr(String collectionName, String expr, List<String> outFields) {
        QueryParam queryParam = QueryParam.newBuilder()
                .withCollectionName(collectionName)
                .withExpr(expr)
                .withOutFields(outFields)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).query(queryParam);
    }

    /**
     * Queries the given partitions of a collection with a filter expression.
     *
     * @param collectionName collection to query
     * @param partitionNames partitions to query
     * @param expr           filter expression
     * @param outFields      fields to return
     * @return the raw query response
     */
    public R<QueryResults> queryByExpr(String collectionName, List<String> partitionNames, String expr, List<String> outFields) {
        QueryParam queryParam = QueryParam.newBuilder()
                .withCollectionName(collectionName)
                .withPartitionNames(partitionNames)
                .withExpr(expr)
                .withOutFields(outFields)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).query(queryParam);
    }

    /**
     * Deletes the entities of a collection that match a filter expression.
     *
     * @param collectionName collection to delete from
     * @param expr           filter expression selecting the entities to delete
     * @return the mutation response
     */
    public R<MutationResult> deleteByExpr(String collectionName, String expr) {
        DeleteParam deleteParam = DeleteParam.newBuilder()
                .withCollectionName(collectionName)
                .withExpr(expr)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).delete(deleteParam);
    }

    /**
     * Deletes the entities of one partition that match a filter expression.
     *
     * @param collectionName collection to delete from
     * @param partitionName  partition to delete from
     * @param expr           filter expression selecting the entities to delete
     * @return the mutation response
     */
    public R<MutationResult> deleteByExpr(String collectionName, String partitionName, String expr) {
        DeleteParam deleteParam = DeleteParam.newBuilder()
                .withCollectionName(collectionName)
                .withPartitionName(partitionName)
                .withExpr(expr)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).delete(deleteParam);
    }

    /**
     * Creates a collection.
     *
     * @param collectionName name of the new collection
     * @param description    collection description
     * @param fieldTypeList  schema of the collection, one entry per field
     * @return the status response
     */
    public R<RpcStatus> createCollection(String collectionName, String description, List<FieldType> fieldTypeList) {
        CreateCollectionParam createCollectionParam = CreateCollectionParam.newBuilder()
                .withCollectionName(collectionName)
                .withDescription(description)
                .withFieldTypes(fieldTypeList)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).createCollection(createCollectionParam);
    }

    /**
     * Drops a collection.
     *
     * @param collectionName collection to drop
     * @return the status response
     */
    public R<RpcStatus> dropCollection(String collectionName) {
        DropCollectionParam dropCollectionParam = DropCollectionParam.newBuilder()
                .withCollectionName(collectionName)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).dropCollection(dropCollectionParam);
    }

    /**
     * Checks whether a collection exists.
     *
     * @param collectionName collection to look for
     * @return the response carrying the existence flag
     */
    public R<Boolean> hasCollection(String collectionName) {
        HasCollectionParam hasCollectionParam = HasCollectionParam.newBuilder()
                .withCollectionName(collectionName)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).hasCollection(hasCollectionParam);
    }

    /**
     * Creates an {@link IndexType#IVF_FLAT} index with metric {@link MetricType#IP} on a field.
     *
     * @param collectionName collection that owns the field
     * @param fieldName      field to index
     * @return the status response
     */
    public R<RpcStatus> createIndex(String collectionName, String fieldName) {
        CreateIndexParam createIndexParam = CreateIndexParam.newBuilder()
                .withCollectionName(collectionName)
                .withFieldName(fieldName)
                .withIndexType(IndexType.IVF_FLAT)
                .withMetricType(MetricType.IP)
                .withExtraParam(INDEX_PARAMS_JSON)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).createIndex(createIndexParam);
    }

    /**
     * Creates a partition in a collection.
     *
     * @param collectionName collection to add the partition to
     * @param partitionName  name of the new partition
     * @return the status response
     */
    public R<RpcStatus> createPartition(String collectionName, String partitionName) {
        CreatePartitionParam createPartitionParam = CreatePartitionParam.newBuilder()
                .withCollectionName(collectionName)
                .withPartitionName(partitionName)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).createPartition(createPartitionParam);
    }

    /**
     * Checks whether a partition exists in a collection.
     *
     * @param collectionName collection to look in
     * @param partitionName  partition to look for
     * @return the response carrying the existence flag
     */
    public R<Boolean> hasPartition(String collectionName, String partitionName) {
        HasPartitionParam hasPartitionParam = HasPartitionParam.newBuilder()
                .withCollectionName(collectionName)
                .withPartitionName(partitionName)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).hasPartition(hasPartitionParam);
    }

    /**
     * Lists collections.
     *
     * @param collectionNames names to restrict the listing to; when {@code null} or empty no name
     *                        filter is set on the request
     * @return the raw listing response
     */
    public R<ShowCollectionsResponse> listCollections(List<String> collectionNames) {
        ShowCollectionsParam.Builder builder = ShowCollectionsParam.newBuilder();
        if (CollUtil.isNotEmpty(collectionNames)) {
            builder.withCollectionNames(collectionNames);
        }
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).showCollections(builder.build());
    }

    /**
     * Lists the partitions of a collection without any partition-name filter.
     *
     * @param collectionName collection whose partitions are listed
     * @return the raw listing response
     */
    public R<ShowPartitionsResponse> listAllPartitions(String collectionName) {
        ShowPartitionsParam showPartitionsParam = ShowPartitionsParam.newBuilder()
                .withCollectionName(collectionName)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).showPartitions(showPartitionsParam);
    }

    /**
     * Lists partitions of a collection.
     *
     * @param collectionName collection whose partitions are listed
     * @param partitionNames names to restrict the listing to; when {@code null} or empty no name
     *                       filter is set on the request
     * @return the raw listing response
     */
    public R<ShowPartitionsResponse> listPartitions(String collectionName, List<String> partitionNames) {
        ShowPartitionsParam.Builder builder = ShowPartitionsParam.newBuilder().withCollectionName(collectionName);
        if (CollUtil.isNotEmpty(partitionNames)) {
            builder.withPartitionNames(partitionNames);
        }
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).showPartitions(builder.build());
    }

    /**
     * Drops a partition from a collection.
     *
     * @param collectionName collection that owns the partition
     * @param partitionName  partition to drop
     * @return the status response
     */
    public R<RpcStatus> dropPartition(String collectionName, String partitionName) {
        DropPartitionParam dropPartitionParam = DropPartitionParam.newBuilder()
                .withCollectionName(collectionName)
                .withPartitionName(partitionName)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).dropPartition(dropPartitionParam);
    }

    /**
     * Flushes the given collections.
     *
     * @param collectionNames collections to flush
     * @return the raw flush response
     */
    public R<FlushResponse> flush(List<String> collectionNames) {
        FlushParam flushParam = FlushParam.newBuilder()
                .withCollectionNames(collectionNames)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).flush(flushParam);
    }

    /**
     * Describes a collection.
     *
     * @param collectionName collection to describe
     * @return a wrapper around the description returned by the server
     * @throws IllegalStateException if the response carries no data; the exception reported by the
     *                               client, if any, is attached as the cause
     */
    public DescCollResponseWrapper describeCollection(String collectionName) {
        DescribeCollectionParam param = DescribeCollectionParam.newBuilder()
                .withCollectionName(collectionName)
                .build();
        R<DescribeCollectionResponse> response = milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS)
                .describeCollection(param);
        return new DescCollResponseWrapper(requireData(response, "describeCollection", collectionName));
    }

    /**
     * Fetches the statistics of a collection.
     *
     * @param collectionName collection to inspect
     * @return a wrapper around the statistics returned by the server
     * @throws IllegalStateException if the response carries no data; the exception reported by the
     *                               client, if any, is attached as the cause
     */
    public GetCollStatResponseWrapper getCollectionStatistics(String collectionName) {
        GetCollectionStatisticsParam param = GetCollectionStatisticsParam.newBuilder()
                .withCollectionName(collectionName)
                .build();
        R<GetCollectionStatisticsResponse> response = milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS)
                .getCollectionStatistics(param);
        return new GetCollStatResponseWrapper(requireData(response, "getCollectionStatistics", collectionName));
    }

    /**
     * Fetches query segment information for a collection.
     *
     * @param collectionName collection to inspect
     * @return the raw segment-info response
     */
    public R<GetQuerySegmentInfoResponse> getQuerySegmentInfo(String collectionName) {
        GetQuerySegmentInfoParam param = GetQuerySegmentInfoParam.newBuilder()
                .withCollectionName(collectionName)
                .build();
        return milvusServiceClient.withTimeout(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS).getQuerySegmentInfo(param);
    }

    /** Starts a search builder with every setting the search overloads have in common. */
    private SearchParam.Builder searchBuilder(String collectionName, String vectorName, List<List<Float>> vectors, List<String> outFields, String expr, Integer topK, String params) {
        return SearchParam.newBuilder()
                .withParams(params)
                .withCollectionName(collectionName)
                .withTopK(topK)
                .withVectorFieldName(vectorName)
                .withVectors(vectors)
                .withMetricType(MetricType.IP)
                .withOutFields(outFields)
                .withExpr(expr);
    }

    /** Serializes the search parameters (nprobe and page offset) for a paged search. */
    private String pagedSearchParams(Integer pageNo, Integer pageSize) {
        // NOTE(review): setOneAsFirstPageNo() looks like it changes a PageUtil-wide setting rather
        // than something local to this call - confirm no other PageUtil user expects otherwise.
        PageUtil.setOneAsFirstPageNo();
        int offset = PageUtil.getStart(pageNo, pageSize);
        return JSONUtil.toJsonStr(SearchParamsDto.builder().nprobe(SEARCH_NPROBE).offset(offset).build());
    }

    /** Converts the field DTOs into the field objects expected by {@link InsertParam}. */
    private List<InsertParam.Field> toInsertFields(List<FieldDto> fieldDtoList) {
        List<InsertParam.Field> insertFields = Lists.newArrayList();
        for (FieldDto fieldDto : fieldDtoList) {
            insertFields.add(new InsertParam.Field(fieldDto.getName(), fieldDto.getValues()));
        }
        return insertFields;
    }

    /**
     * Returns the payload of a response, or fails with the client-reported exception as cause
     * when there is none (presumably because the call failed - the wrappers built from this
     * payload cannot do anything useful with {@code null}).
     */
    private static <T> T requireData(R<T> response, String operation, String collectionName) {
        T data = response.getData();
        if (data == null) {
            throw new IllegalStateException(
                    operation + " returned no data for collection '" + collectionName + "'",
                    response.getException());
        }
        return data;
    }

}
